from __future__ import print_function, absolute_import

import json
# import moxing as mox
import os.path as osp
import shutil
import sys

import torch
from torch.nn import Parameter

from config import get_args
from .distributed_ops import is_main_process
from .osutils import mkdir_if_missing

# Parse this process's command-line arguments once, at import time, so the
# helpers below can branch on `global_args.run_on_remote`.
# NOTE(review): this runs on every import of this module and parses the host
# script's sys.argv; presumably get_args accepts that script's flags — confirm.
global_args = get_args(sys.argv[1:])

# `moxing` is imported only when running remotely. In local runs the name `mox`
# is never bound, so it may only be referenced inside `run_on_remote` branches.
if global_args.run_on_remote:
    import moxing as mox


def read_json(fpath):
    """Parse the JSON file at ``fpath`` and return the decoded object."""
    with open(fpath, 'r') as handle:
        return json.load(handle)


def write_json(obj, fpath):
    """Serialize ``obj`` as pretty-printed JSON to ``fpath``.

    The parent directory of ``fpath`` is created first if it has one.

    Args:
        obj: Any JSON-serializable object.
        fpath: Destination file path; it is overwritten if it exists.
    """
    dir_name = osp.dirname(fpath)
    # A bare filename such as 'out.json' has an empty dirname. os.makedirs('')
    # raises, so presumably mkdir_if_missing('') would too; the current
    # directory needs no creating anyway.
    if dir_name:
        mkdir_if_missing(dir_name)
    with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """Save ``state`` with ``torch.save`` and optionally also as the best model.

    Only the main process writes anything, in both the remote and the local
    branch. Otherwise every rank would race on the same output files, including
    the shared local staging file used for remote uploads.

    Args:
        state: Object handed to ``torch.save``.
        is_best: If true, a copy named ``model_best.pth.tar`` is also written
            in the same directory as ``fpath``.
        fpath: Destination path. When ``global_args.run_on_remote`` is set it
            is handled through ``mox.file``; otherwise it is a local path.
    """
    print('=> saving checkpoint ', fpath)
    if not is_main_process():
        return

    dir_name = osp.dirname(fpath)
    best_path = osp.join(dir_name, 'model_best.pth.tar')

    if global_args.run_on_remote:
        # An empty dirname (bare filename, e.g. the default fpath) means
        # there is no directory to check or create.
        if dir_name and not mox.file.exists(dir_name):
            mox.file.make_dirs(dir_name)
            print('=> making dir ', dir_name)
        # Serialize to a local staging file first, then copy it to fpath
        # (and to the best-model path) with mox.file.copy.
        local_path = "local_checkpoint.pth.tar"
        torch.save(state, local_path)
        mox.file.copy(local_path, fpath)
        if is_best:
            mox.file.copy(local_path, best_path)
    else:
        # Skip directory creation for a bare filename, where dirname is ''.
        if dir_name:
            mkdir_if_missing(dir_name)
        torch.save(state, fpath)
        if is_best:
            shutil.copy(fpath, best_path)


def load_checkpoint(fpath, device):
    """Load and return a checkpoint written by ``torch.save``.

    Args:
        fpath: Checkpoint path. When ``global_args.run_on_remote`` is set,
            its existence is checked through ``mox.file``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        ValueError: If no checkpoint exists at ``fpath``. This applies to
            both the remote and the local branch.
    """
    if global_args.run_on_remote:
        # NOTE(review): shift('os', 'mox') appears to redirect the built-in
        # file APIs to moxing so torch.load can read remote paths — confirm.
        # It is re-applied on every call.
        mox.file.shift('os', 'mox')
        found = mox.file.exists(fpath)
    else:
        found = osp.isfile(fpath)

    if not found:
        raise ValueError("=> No checkpoint found at '{}'".format(fpath))

    # torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(fpath, map_location=device)
    print("=> Loaded checkpoint '{}'".format(fpath))
    return checkpoint


def copy_state_dict(state_dict, model, strip=None):
    """Copy matching tensors from ``state_dict`` into ``model`` in place.

    Args:
        state_dict: Mapping of parameter/buffer names to tensors or Parameters.
        model: Module whose ``state_dict()`` tensors are overwritten.
        strip: Optional prefix removed from a source key when the key starts
            with it (e.g. a wrapper's ``'module.'`` prefix).

    Entries whose (stripped) name is unknown to ``model`` are skipped
    silently. Entries whose size differs are reported and skipped. Target
    entries that were never filled are printed at the end.

    Returns:
        The same ``model`` object.
    """
    target = model.state_dict()
    loaded = set()

    for key, value in state_dict.items():
        if strip is not None and key.startswith(strip):
            key = key[len(strip):]
        if key not in target:
            continue

        tensor = value.data if isinstance(value, Parameter) else value
        dest = target[key]
        if tensor.size() != dest.size():
            print('mismatch:', key, tensor.size(), dest.size())
            continue

        dest.copy_(tensor)
        loaded.add(key)

    not_loaded = set(target) - loaded
    if not_loaded:
        print("missing keys in state_dict:", not_loaded)

    return model
